{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "66575b61-30a5-4471-9e4e-a45c6fc396a9",
   "metadata": {},
   "source": [
    "# Implement RRT and its variant on UR5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3b9c224-8a28-4554-b0ee-a8ac303aa6df",
   "metadata": {},
   "outputs": [],
   "source": [
    "import magic_donotload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5d79912-1a64-4466-8017-70724567b28c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import example_robot_data as robex\n",
    "import hppfcl\n",
    "import math\n",
    "import numpy as np\n",
    "import pinocchio as pin\n",
    "import time\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8abf514d-0c36-4c32-934f-e1d013c57ea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pylab as plt; plt.ion()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf038376-9d69-48e9-9bdb-0bcf3b641247",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.meshcat_viewer_wrapper import MeshcatVisualizer, colors\n",
    "from utils.datastructures.storage import Storage\n",
    "from utils.datastructures.pathtree import PathTree\n",
    "from utils.datastructures.mtree import MTree\n",
    "from tp4.collision_wrapper import CollisionWrapper"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58561974-a3d2-4692-821c-30f4d0caf96f",
   "metadata": {},
   "source": [
    "## Load UR5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7217ea2c-5cac-430e-8496-93e80e78b859",
   "metadata": {},
   "outputs": [],
   "source": [
    "robot = robex.load('ur5')\n",
    "collision_model = robot.collision_model\n",
    "visual_model = robot.visual_model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98805b5f-f993-4d4a-98ca-c465ff363424",
   "metadata": {},
   "source": [
    "Recall some placement for the UR5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f521701-7416-4736-9363-305a45258d14",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = robot.placement(robot.q0, 6)  # Placement of the end effector joint.\n",
    "b = robot.framePlacement(robot.q0, 22)  # Placement of the end effector tip.\n",
    "\n",
    "tool_axis = b.rotation[:, 2]  # Axis of the tool\n",
    "tool_position = b.translation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d96626e3-1b22-446f-916a-f7ee54b78f04",
   "metadata": {},
   "outputs": [],
   "source": [
    "viz = MeshcatVisualizer(robot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d1d4ee-c1f4-4350-9370-ed8f32dee4cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "viz.viewer.jupyter_cell()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ede57d4-8b0d-4a69-adbd-95901db503cd",
   "metadata": {},
   "source": [
    "Set a start and a goal configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b14d4cd-9e44-4e14-9a56-ec08256d0238",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_i = np.array([1., -1.5, 2.1, -.5, -.5, 0])\n",
    "q_g = np.array([3., -1., 1, -.5, -.5, 0])\n",
    "radius = 0.05"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a05303fd-5fa9-47bf-92cd-63d64e2212cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "viz.display(q_i)\n",
    "M = robot.framePlacement(q_i, 22)\n",
    "name = \"world/sph_initial\"\n",
    "viz.addSphere(name, radius, [0., 1., 0., 1.])\n",
    "viz.applyConfiguration(name,M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbc882d4-b335-4255-9db3-3aa61eeb8ee7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "viz.display(q_g)\n",
    "M = robot.framePlacement(q_g, 22)\n",
    "name = \"world/sph_goal\"\n",
    "viz.addSphere(name, radius, [0., 0., 1., 1.])\n",
    "viz.applyConfiguration(name,M)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a4bce9b-0c81-4c7a-bee9-02c9aa59b829",
   "metadata": {},
   "source": [
    "## Implement everything needed for RRT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d4c45e0-ee85-47b1-88ae-94c3c8fb50d8",
   "metadata": {},
   "source": [
    "We abstract the robot the environment and its behaviour in a class call `System`\n",
    "\n",
    "It must be able to:\n",
    "- generate random configuration which are not colliding if needed (sampling)\n",
    "- implement a distance on the configuration space (distance)\n",
    "- generate path between two configuration (steering)\n",
    "- check if a path is free between two configuration and return the latest free config (directional free steering)\n",
    "and some function to display the configuration.\n",
    "\n",
    "Recall that in the case of the UR5 the configuration space is $S_1^{6}$, where $S_1$ is the unit cirle, we can parametrize by $\\theta\\in[-\\pi,\\pi]$ such that $-\\pi$ and $\\pi$ are identified.\n",
    "\n",
    "In the next cell, you must implement the system behaviour for the UR5."
   ]
  },
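  {
   "cell_type": "markdown",
   "id": "sketch-wrapped-distance",
   "metadata": {},
   "source": [
    "Because $-\\pi$ and $\\pi$ are identified, the distance has to fold each joint difference back into $[0, \\pi]$. The following is a minimal sketch of that idea on a two-joint example; `wrapped_distance` is a hypothetical standalone helper used only for illustration, not part of the course utilities.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def wrapped_distance(q1, q2):\n",
    "    # Per-joint angular gap, folded into [0, pi] so that -pi and pi coincide.\n",
    "    e = np.mod(np.abs(q1 - q2), 2 * np.pi)\n",
    "    e = np.minimum(e, 2 * np.pi - e)\n",
    "    return np.linalg.norm(e, axis=-1)\n",
    "\n",
    "q_a = np.array([3.0, 0.0])\n",
    "q_b = np.array([-3.0, 0.0])\n",
    "naive = np.linalg.norm(q_a - q_b)     # 6.0: ignores the wrap-around\n",
    "wrapped = wrapped_distance(q_a, q_b)  # 2 * pi - 6: goes through the identification\n",
    "```\n",
    "\n",
    "The two configurations are almost 6 rad apart when measured naively, but less than 0.3 rad apart on the circle."
   ]
  },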
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b74e5455-eae8-49fe-9030-f403467df288",
   "metadata": {},
   "outputs": [],
   "source": [
    "class System():\n",
    "\n",
    "    def __init__(self, robot):\n",
    "        self.robot = robot\n",
    "        robot.gmodel = robot.collision_model\n",
    "        self.display_edge_count = 0\n",
    "        self.colwrap = CollisionWrapper(robot)  # For collision checking\n",
    "        self.nq = self.robot.nq\n",
    "        self.display_count = 0\n",
    "        \n",
    "    def distance(self, q1, q2):\n",
    "        \"\"\"\n",
    "        Must return a distance between q1 and q2 which can be a batch of config.\n",
    "        \"\"\"\n",
    "        if len(q2.shape) > len(q1.shape):\n",
    "            q1 = q1[None, ...]\n",
    "        e = np.mod(np.abs(q1 - q2), 2 * np.pi)\n",
    "        e[e > np.pi] = 2 * np.pi - e[e > np.pi]\n",
    "        return np.linalg.norm(e, axis=-1)\n",
    "\n",
    "    def random_config(self, free=True):\n",
    "        \"\"\"\n",
    "        Must return a random configuration which is not in collision if free=True\n",
    "        \"\"\"\n",
    "        q = 2 * np.pi * np.random.rand(6) - np.pi\n",
    "        if not free:\n",
    "            return q\n",
    "        while self.is_colliding(q):\n",
    "            q = 2 * np.pi * np.random.rand(6) - np.pi\n",
    "        return q\n",
    "\n",
    "    def is_colliding(self, q):\n",
    "        \"\"\"\n",
    "        Use CollisionWrapper to decide if a configuration is in collision\n",
    "        \"\"\"\n",
    "        self.colwrap.computeCollisions(q)\n",
    "        collisions = self.colwrap.getCollisionList()\n",
    "        return (len(collisions) > 0)\n",
    "\n",
    "    def get_path(self, q1, q2, l_min=None, l_max=None, eps=0.2):\n",
    "        \"\"\"\n",
    "        generate a continuous path with precision eps between q1 and q2\n",
    "        If l_min of l_max is mention, extrapolate or cut the path such\n",
    "        that \n",
    "        \"\"\"\n",
    "        q1 = np.mod(q1 + np.pi, 2 * np.pi) - np.pi\n",
    "        q2 = np.mod(q2 + np.pi, 2 * np.pi) - np.pi\n",
    "\n",
    "        diff = q2 - q1\n",
    "        query = np.abs(diff) > np.pi\n",
    "        q2[query] = q2[query] - np.sign(diff[query]) * 2 * np.pi\n",
    "\n",
    "        d = self.distance(q1, q2)\n",
    "        if d < eps:\n",
    "            return np.stack([q1, q2], axis=0)\n",
    "        \n",
    "        if l_min is not None or l_max is not None:\n",
    "            new_d = np.clip(d, l_min, l_max)\n",
    "        else:\n",
    "            new_d = d\n",
    "            \n",
    "        N = int(new_d / eps + 2)\n",
    "\n",
    "        return np.linspace(q1, q1 + (q2 - q1) * new_d / d, N)\n",
    "        \n",
    "    def is_free_path(self, q1, q2, l_min=0.2, l_max=1., eps=0.2):\n",
    "        \"\"\"\n",
    "        Create a path and check collision to return the last\n",
    "         non-colliding configuration. Return X, q where X is a boolean\n",
    "        who state is the steering has work.\n",
    "        We require at least l_min must be cover without collision to validate the path.\n",
    "        \"\"\"\n",
    "        q_path = self.get_path(q1, q2, l_min, l_max, eps)\n",
    "        N = len(q_path)\n",
    "        N_min = N - 1 if l_min is None else min(N - 1, int(l_min / eps))\n",
    "        for i in range(N):\n",
    "            if self.is_colliding(q_path[i]):\n",
    "                break\n",
    "        if i < N_min:\n",
    "            return False, None\n",
    "        if i == N - 1:\n",
    "            return True, q_path[-1]\n",
    "        return True, q_path[i - 1]\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"\n",
    "        Reset the system visualization\n",
    "        \"\"\"\n",
    "        for i in range(self.display_count):\n",
    "            viz.delete(f\"world/sph{i}\")\n",
    "            viz.delete(f\"world/cil{i}\")\n",
    "        self.display_count = 0\n",
    "    \n",
    "    def display_edge(self, q1, q2, radius=0.01, color=[1.,0.,0.,1]):\n",
    "        M1 = self.robot.framePlacement(q1, 22)  # Placement of the end effector tip.\n",
    "        M2 = self.robot.framePlacement(q2, 22)  # Placement of the end effector tip.\n",
    "        middle = .5 * (M1.translation + M2.translation)\n",
    "        direction = M2.translation - M1.translation\n",
    "        length = np.linalg.norm(direction)\n",
    "        dire = direction / length\n",
    "        orth = np.cross(dire, np.array([0, 0, 1]))\n",
    "        orth /= np.linalg.norm(orth)\n",
    "        orth2 = np.cross(dire, orth)\n",
    "        orth2 /= np.linalg.norm(orth2)\n",
    "        Mcyl = pin.SE3(np.stack([orth2, dire, orth], axis=1), middle)\n",
    "        name = f\"world/sph{self.display_count}\"\n",
    "        viz.addSphere(name, radius, [1.,0.,0.,1])\n",
    "        viz.applyConfiguration(name,M2)\n",
    "        name = f\"world/cil{self.display_count}\"\n",
    "        viz.addCylinder(name, length, radius / 4, [0., 1., 0., 1])\n",
    "        viz.applyConfiguration(name,Mcyl)\n",
    "        self.display_count +=1\n",
    "        \n",
    "    def display_motion(self, qs, step=1e-1):\n",
    "        # Given a point path display the smooth movement\n",
    "        for i in range(len(qs) - 1):\n",
    "            for q in self.get_path(qs[i], qs[i+1])[:-1]:\n",
    "                viz.display(q)\n",
    "                time.sleep(step)\n",
    "        viz.display(qs[-1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51178d60-0474-4d57-9336-d8f1617b8a4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "system = System(robot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b0ced1c-2211-459a-aeb3-fb9c3b25085d",
   "metadata": {},
   "outputs": [],
   "source": [
    "system.distance(q_i, q_g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fafd71f-87f0-44b3-a92c-48f53b6580ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "system.display_motion(system.get_path(q_i, q_g))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "914d0701-bc8f-46c0-8888-8e019f15cfc3",
   "metadata": {},
   "source": [
    "## RRT implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a64cb4c-28b0-4e66-9324-1cba47374186",
   "metadata": {},
   "source": [
    "In its most simple form, RRT construct a tree from the start, eventually with a bias toward the goal. In the following class, we add some memoization to avoid recomputing distances. The kNN (k Nearest Neighbors) structure works on node indices."
   ]
  },
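  {
   "cell_type": "markdown",
   "id": "sketch-toy-rrt",
   "metadata": {},
   "source": [
    "Stripped of memoization and collision checking, this loop can be sketched on a toy problem. The following is a minimal sketch under simplifying assumptions, not the implementation used in this notebook: `toy_rrt` and all of its parameters are hypothetical, the space is the obstacle-free unit square, and a brute-force search stands in for the nearest-neighbour structure.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def toy_rrt(start, goal, n_bias=10, step=0.1, tol=0.05, iter_max=5000, seed=0):\n",
    "    # Hypothetical helper: RRT in the obstacle-free unit square.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    nodes = [np.asarray(start, dtype=float)]\n",
    "    parents = [-1]  # Index of the parent of each node, -1 for the root.\n",
    "    goal = np.asarray(goal, dtype=float)\n",
    "    for it in range(iter_max):\n",
    "        # Goal bias: every n_bias-th iteration, aim at the goal itself.\n",
    "        target = goal if it % n_bias == 0 else rng.random(2)\n",
    "        # Nearest neighbour by brute force.\n",
    "        dists = np.linalg.norm(np.array(nodes) - target, axis=1)\n",
    "        near = int(np.argmin(dists))\n",
    "        if dists[near] == 0.0:\n",
    "            continue\n",
    "        # Steer: move at most step from the nearest node toward the target.\n",
    "        ratio = min(1.0, step / dists[near])\n",
    "        new = nodes[near] + (target - nodes[near]) * ratio\n",
    "        nodes.append(new)\n",
    "        parents.append(near)\n",
    "        if np.linalg.norm(new - goal) < tol:\n",
    "            # Walk back up the tree to recover the path.\n",
    "            path, idx = [], len(nodes) - 1\n",
    "            while idx != -1:\n",
    "                path.append(nodes[idx])\n",
    "                idx = parents[idx]\n",
    "            return np.array(path[::-1])\n",
    "    return None\n",
    "\n",
    "path = toy_rrt([0.1, 0.1], [0.9, 0.9])\n",
    "```\n",
    "\n",
    "Each row of `path` is a waypoint, from the start to a node within `tol` of the goal. Lowering `n_bias` makes the tree greedier toward the goal, which helps in free space but can get stuck behind obstacles."
   ]
  },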
  {
   "cell_type": "markdown",
   "id": "3b1a4b31-07d2-463f-851e-cd4d2ca09bbf",
   "metadata": {},
   "source": [
    "Let us look at an implementation the core algorithm:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62a504d-9ac1-4bcc-a149-cd6516f90a0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RRT():\n",
    "    \"\"\"\n",
    "    Can be splited into RRT base because different rrt\n",
    "    have factorisable logic\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        system,\n",
    "        node_max=500000,\n",
    "        iter_max=1000000,\n",
    "        N_bias=10,\n",
    "        l_min=.2,\n",
    "        l_max=.5,\n",
    "        steer_delta=.1,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        [Here, in proper code, we would document the parameters of our function. Do that below,\n",
    "        using the Google style for docstrings.]\n",
    "        https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html\n",
    "\n",
    "        Args:\n",
    "            node_max: ...\n",
    "            iter_max: ...\n",
    "            ...\n",
    "        \"\"\"\n",
    "        self.system = system\n",
    "        # params\n",
    "        self.l_max = l_max\n",
    "        self.l_min = l_min\n",
    "        self.N_bias = N_bias\n",
    "        self.node_max = node_max\n",
    "        self.iter_max = iter_max\n",
    "        self.steer_delta = steer_delta\n",
    "        # intern\n",
    "        self.NNtree = None\n",
    "        self.storage = None\n",
    "        self.pathtree = None\n",
    "        # The distance function will be called on N, dim object\n",
    "        self.real_distance = self.system.distance\n",
    "        # Internal for computational_opti in calculating distance\n",
    "        self._candidate = None\n",
    "        self._goal = None\n",
    "        self._cached_dist_to_candidate = {}\n",
    "        self._cached_dist_to_goal = {}\n",
    "\n",
    "    def distance(self, q1_idx, q2_idx):\n",
    "        if isinstance(q2_idx, int):\n",
    "            if q1_idx == q2_idx:\n",
    "                return 0.\n",
    "            if q1_idx == -1 or q2_idx == -1:\n",
    "                if q2_idx == -1:\n",
    "                    q1_idx, q2_idx = q2_idx, q1_idx\n",
    "                if q2_idx not in self._cached_dist_to_candidate:\n",
    "                    self._cached_dist_to_candidate[q2_idx] = self.real_distance(\n",
    "                        self._candidate, self.storage[q2_idx]\n",
    "                    )\n",
    "                return self._cached_dist_to_candidate[q2_idx]\n",
    "            if q1_idx == -2 or q2_idx == -2:\n",
    "                if q2_idx == -2:\n",
    "                    q1_idx, q2_idx = q2_idx, q1_idx\n",
    "                if q2_idx not in self._cached_dist_to_goal:\n",
    "                    self._cached_dist_to_goal[q2_idx] = self.real_distance(\n",
    "                        self._goal, self.storage[q2_idx]\n",
    "                    )\n",
    "                return self._cached_dist_to_goal[q2_idx]\n",
    "            return self.real_distance(self.storage[q1_idx], self.storage[q2_idx])\n",
    "        if q1_idx == -1:\n",
    "            q = self._candidate\n",
    "        elif q1_idx == -2:\n",
    "            q = self._goal\n",
    "        else:\n",
    "            q = self.storage[q1_idx]\n",
    "        return self.real_distance(q, self.storage[q2_idx])\n",
    "\n",
    "    def new_candidate(self):\n",
    "        q = self.system.random_config(free=True)\n",
    "        self._candidate = q\n",
    "        self._cached_dist_to_candidate = {}\n",
    "        return q\n",
    "\n",
    "    def solve(self, qi, validate, qg=None):\n",
    "        self.system.reset()\n",
    "        self._goal = qg\n",
    "        \n",
    "        # Reset internal datastructures\n",
    "        self.storage = Storage(self.node_max, self.system.nq)\n",
    "        self.pathtree = PathTree(self.storage)\n",
    "        self.NNtree = MTree(self.distance)\n",
    "        qi_idx = self.storage.add_point(qi)\n",
    "        self.NNtree.add_point(qi_idx)\n",
    "        self.it_trace = []\n",
    "\n",
    "        found = False\n",
    "        iterator = range(self.iter_max)\n",
    "        for i in tqdm(iterator):\n",
    "            # New candidate\n",
    "            if i % self.N_bias == 0:\n",
    "                q_new = self._goal\n",
    "                q_new_idx = -2\n",
    "            else:\n",
    "                q_new = self.new_candidate()\n",
    "                q_new_idx = -1\n",
    "\n",
    "            # Find closest neighboor to q_new\n",
    "            q_near_idx, d = self.NNtree.nearest_neighbour(q_new_idx)\n",
    "            \n",
    "            # Steer from it toward the new checking for colision\n",
    "            success, q_prox = self.system.is_free_path(\n",
    "                self.storage.data[q_near_idx],\n",
    "                q_new,\n",
    "                l_min=self.l_min,\n",
    "                l_max=self.l_max,\n",
    "                eps=self.steer_delta\n",
    "            )\n",
    "\n",
    "            if not success:\n",
    "                self.it_trace.append(0)\n",
    "                continue\n",
    "            self.it_trace.append(1)\n",
    "            \n",
    "            # Add the points in data structures\n",
    "            q_prox_idx = self.storage.add_point(q_prox)\n",
    "            self.NNtree.add_point(q_prox_idx)\n",
    "            self.pathtree.update_link(q_prox_idx, q_near_idx)\n",
    "            self.system.display_edge(self.storage[q_near_idx], self.storage[q_prox_idx])\n",
    "\n",
    "            # Test if it reach the goal\n",
    "            if validate(q_prox):\n",
    "                q_g_idx = self.storage.add_point(q_prox)\n",
    "                self.NNtree.add_point(q_g_idx)\n",
    "                self.pathtree.update_link(q_g_idx, q_prox_idx)\n",
    "                found = True\n",
    "                break\n",
    "        self.iter_done = i + 1\n",
    "        self.found = found\n",
    "        return found\n",
    "\n",
    "    def get_path(self, q_g):\n",
    "        assert self.found\n",
    "        path = self.pathtree.get_path()\n",
    "        return np.concatenate([path, q_g[None, :]])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73533b79-9692-4856-8094-b4874e3944ef",
   "metadata": {},
   "source": [
    "In proper code, we would document the parameters of our functions.\n",
    "\n",
    "- **Your turn:** Add docstrings to the code above, following the [Google style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html).\n",
    "- Optional: you are welcome to add type annotations if you'd like.\n",
    "\n",
    "The constructor of the `RRT` class invites you to start."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "811cc5ae-b601-477d-9d55-e560d0e45262",
   "metadata": {},
   "source": [
    "For this problem, we will instantiate our RRT with the following parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "676b991b-32d1-4e2f-8dd9-ed85252e180c",
   "metadata": {},
   "outputs": [],
   "source": [
    "rrt = RRT(\n",
    "    system,\n",
    "    N_bias=20,\n",
    "    l_min=0.2,\n",
    "    l_max=0.5,\n",
    "    steer_delta=0.1,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f20ecc74-a2b6-4b91-8131-cbb7ff3e13a9",
   "metadata": {},
   "source": [
    "Now let's define our termination condition, and run the main function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de20f3cc-6df3-4e5c-ab5b-e183616d4a2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "eps_final = .1\n",
    "def validation(key):\n",
    "    vec = robot.framePlacement(key, 22).translation - robot.framePlacement(q_g, 22).translation\n",
    "    return (float(np.linalg.norm(vec)) < eps_final)\n",
    "\n",
    "rrt.solve(q_i, validation, qg=q_g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f3a6a8d-4fd9-419b-8176-214543c08a22",
   "metadata": {},
   "outputs": [],
   "source": [
    "system.display_motion(rrt.get_path(q_g))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d82012a2-c7c4-4362-85ae-2c8e67a2e2a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "system.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bd7d92a-abea-482b-8a91-a5942a125585",
   "metadata": {},
   "source": [
    "## Create obstacle with environments"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8784c9ad-ab66-4440-9a0f-1a1e10ab4b2d",
   "metadata": {},
   "source": [
    "We already had some simple algorithms to find free paths, *i.e.* without obstacles. Let us now add some obstacles to the environment:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5efbc0a8-bd55-4122-b09a-168b64f8de19",
   "metadata": {},
   "outputs": [],
   "source": [
    "robot = robex.load('ur5')\n",
    "collision_model = robot.collision_model\n",
    "visual_model = robot.visual_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "189570f8-d592-4f6d-b3d7-89a466fdf898",
   "metadata": {},
   "outputs": [],
   "source": [
    "def addCylinderToUniverse(name, radius, length, placement, color=colors.red):\n",
    "    geom = pin.GeometryObject(\n",
    "        name,\n",
    "        0,\n",
    "        hppfcl.Cylinder(radius, length),\n",
    "        placement\n",
    "    )\n",
    "    new_id = collision_model.addGeometryObject(geom)\n",
    "    geom.meshColor = np.array(color)\n",
    "    visual_model.addGeometryObject(geom)\n",
    "    \n",
    "    for link_id in range(robot.model.nq):\n",
    "        collision_model.addCollisionPair(\n",
    "            pin.CollisionPair(link_id, new_id)\n",
    "        )\n",
    "    return geom"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c0b531f-992c-47cc-b101-d1113a2f0870",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pinocchio.utils import rotate\n",
    "\n",
    "[collision_model.removeGeometryObject(e.name) for e in collision_model.geometryObjects if e.name.startswith('world/')]\n",
    "\n",
    "# Add a red box in the viewer\n",
    "radius = 0.1\n",
    "length = 1.\n",
    "\n",
    "cylID = \"world/cyl1\"\n",
    "placement = pin.SE3(pin.SE3(rotate('z',np.pi/2), np.array([-0.5,0.4,0.5])))\n",
    "addCylinderToUniverse(cylID,radius,length,placement,color=[.7,.7,0.98,1])\n",
    "\n",
    "\n",
    "cylID = \"world/cyl2\"\n",
    "placement = pin.SE3(pin.SE3(rotate('z',np.pi/2), np.array([-0.5,-0.4,0.5])))\n",
    "addCylinderToUniverse(cylID,radius,length,placement,color=[.7,.7,0.98,1])\n",
    "\n",
    "cylID = \"world/cyl3\"\n",
    "placement = pin.SE3(pin.SE3(rotate('z',np.pi/2), np.array([-0.5,0.7,0.5])))\n",
    "addCylinderToUniverse(cylID,radius,length,placement,color=[.7,.7,0.98,1])\n",
    "\n",
    "\n",
    "cylID = \"world/cyl4\"\n",
    "placement = pin.SE3(pin.SE3(rotate('z',np.pi/2), np.array([-0.5,-0.7,0.5])))\n",
    "addCylinderToUniverse(cylID,radius,length,placement,color=[.7,.7,0.98,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6780c139-e2de-4c9f-8af4-61cb1e7e1b4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_i = np.array([-1., -1.5, 2.1, -.5, -.5, 0])\n",
    "q_g = np.array([3.1, -1., 1, -.5, -.5, 0])\n",
    "radius = 0.05"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c114f554-6126-47e9-8749-934d47d2c7c1",
   "metadata": {},
   "source": [
    "We need to reload the viewer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62ab0ec0-5faf-43fc-bcd8-9ad9cbe3ad56",
   "metadata": {},
   "outputs": [],
   "source": [
    "viz = MeshcatVisualizer(robot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "037ac2f5-e0d3-438a-ab10-9f313a8f8803",
   "metadata": {},
   "outputs": [],
   "source": [
    "viz.display(q_i)\n",
    "M = robot.framePlacement(q_i, 22)\n",
    "name = \"world/sph_initial\"\n",
    "viz.addSphere(name, radius, [0., 1., 0., 1.])\n",
    "viz.applyConfiguration(name,M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e5dfe04-e29d-4455-ba00-c7c158a8930c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "viz.display(q_g)\n",
    "M = robot.framePlacement(q_g, 22)\n",
    "name = \"world/sph_goal\"\n",
    "viz.addSphere(name, radius, [0., 0., 1., 1.])\n",
    "viz.applyConfiguration(name,M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f17610e-e874-4e4f-aaf7-6c6d569f965e",
   "metadata": {},
   "outputs": [],
   "source": [
    "viz.display(q_g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e19068e5-033b-412d-8d64-ffef4def949a",
   "metadata": {},
   "outputs": [],
   "source": [
    "system = System(robot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "661a6589-36ff-44c5-b761-c6f588a8eab7",
   "metadata": {},
   "outputs": [],
   "source": [
    "rrt = RRT(\n",
    "    system,\n",
    "    N_bias=20,\n",
    "    l_min=0.2,\n",
    "    l_max=0.5,\n",
    "    steer_delta=0.1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffa971ee-6396-4533-b875-5c4e7124bc7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "eps_final = .1\n",
    "\n",
    "def validation(key):\n",
    "    vec = robot.framePlacement(key, 22).translation - robot.framePlacement(q_g, 22).translation\n",
    "    return (float(np.linalg.norm(vec)) < eps_final)\n",
    "\n",
    "rrt.solve(q_i, validation, qg=q_g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf22611e-ba2d-4c0e-93a8-03ff922b6073",
   "metadata": {},
   "outputs": [],
   "source": [
    "system.display_motion(rrt.get_path(q_g))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "160cf161-755f-430e-8e91-e926f644fbe2",
   "metadata": {},
   "source": [
    "And solve RRT. It is long right ? Let us implement more efficient algorithms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c90f1487-cd74-4eb8-8f3f-a0f5e65a1aad",
   "metadata": {},
   "source": [
    "## Bi-RRT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d526af5b-9c37-425f-a30c-45bb80588404",
   "metadata": {},
   "source": [
    "Now it's your turn. Make a `BiRRT` class, similar to the `RRT` class above, but implementing the Bi-RRT algorithm. (It is not recommended to try to inherit from `RRT`, as you will end up re-implementing most functions.) Here is a template you are free to adapt, with some advice:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c003e3fe-a042-4e3a-b644-f607acb9faef",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BiRRT(RRT):\n",
    "    def __init__(\n",
    "        self,\n",
    "        system,\n",
    "        node_max=500000,\n",
    "        iter_max=1000000,\n",
    "        l_min=.2,\n",
    "        l_max=.5,\n",
    "        steer_delta=.1,\n",
    "    ):\n",
    "        # Initialize attributes:\n",
    "        # self.l_min = l_min\n",
    "        # etc.\n",
    "\n",
    "        # New: duplicate this attribute as dictionaries with two keys:\n",
    "        # \"forward\" and \"backward\". See `solve` below.\n",
    "        self._cached_dist_to_candidate = {}\n",
    "        self.storage = {}\n",
    "        self.pathtree = {}\n",
    "        self.tree = {}\n",
    "\n",
    "    def tree_distance(self, direction: str, q1_idx, q2_idx):\n",
    "        # Adapt from RRT.distance\n",
    "        # There is now a direction string to select the underlying tree,\n",
    "        # either \"forward\" (from q_init) or \"backward\" (from q_goal).\n",
    "\n",
    "    def forward_distance(self, q1_idx, q2_idx):\n",
    "        return self.tree_distance(\"forward\", q1_idx, q2_idx)\n",
    "\n",
    "    def backward_distance(self, q1_idx, q2_idx):\n",
    "        return self.tree_distance(\"backward\", q1_idx, q2_idx)\n",
    "\n",
    "    def new_candidate(self):\n",
    "        # A minor change is required to adapt RRT.new_candidate to this template.\n",
    "\n",
    "    def solve(self, qi, validate, qg=None):\n",
    "        # Reset internal datastructures\n",
    "        for direction in (\"forward\", \"backward\"):\n",
    "            self._cached_dist_to_candidate[direction] = {}\n",
    "            self.storage[direction] = Storage(node_max, system.nq)\n",
    "            self.pathtree[direction] = PathTree(self.storage[direction])\n",
    "        self.tree = {\n",
    "            \"forward\": MTree(self.forward_distance),\n",
    "            \"backward\": MTree(self.backward_distance),\n",
    "        }\n",
    "\n",
    "        # Now datastructures are initialized\n",
    "        # The rest is up to you! \n",
    "\n",
    "    def get_path(self):\n",
    "        assert self.found\n",
    "        forward_path = self.pathtree[\"forward\"].get_path()\n",
    "        backward_path = self.pathtree[\"backward\"].get_path()\n",
    "        return np.concatenate([forward_path, backward_path[::-1]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5254b62c-673f-407d-a0e6-effd3e75aabb",
   "metadata": {},
   "source": [
    "You should be able to call `BiRRT` similarly to `RRT`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60443eb2-ea54-4c0b-8d24-1908a4ca710b",
   "metadata": {},
   "outputs": [],
   "source": [
    "system.reset()\n",
    "\n",
    "birrt = BiRRT(\n",
    "    system,\n",
    "    l_min=0.2,\n",
    "    l_max=0.5,\n",
    "    steer_delta=0.1,\n",
    ")\n",
    "\n",
    "birrt.solve(q_i, validation, qg=q_g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b62f8275-b4d7-4168-ad6c-e437d0ea2887",
   "metadata": {},
   "outputs": [],
   "source": [
    "system.display_motion(birrt.get_path())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a291e25d-46a4-47b3-b894-8d5810196d39",
   "metadata": {},
   "source": [
    "How many iterations did it take to find a solution? Is it faster than previously with `RRT`?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84285a1d-ac59-40c9-aef1-23527d0bab54",
   "metadata": {},
   "source": [
    "## Bonus question: Bi-RRT*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6212a5aa-5794-48c6-91e5-f70c3ade34fa",
   "metadata": {},
   "source": [
    "Implement an optimal variant `BiRRTStar` of your `BiRRT` class and run it in the same configuration as the two algorithms above. What do you notice about the resulting tree? What is the improvement in overall path length between `RRT`, `BiRRT` and `BiRRTStar`?"
   ]
  }
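  ,
  {
   "cell_type": "markdown",
   "id": "sketch-path-length",
   "metadata": {},
   "source": [
    "To compare path lengths, one option is to sum the wrap-around distances between consecutive waypoints. The following is a minimal sketch; `path_length` is a hypothetical helper, not part of the course utilities, and it assumes a path stored as an array with one configuration per row.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def path_length(path):\n",
    "    # Sum of wrap-around distances between consecutive waypoints.\n",
    "    path = np.asarray(path, dtype=float)\n",
    "    e = np.mod(np.abs(np.diff(path, axis=0)), 2 * np.pi)\n",
    "    e = np.minimum(e, 2 * np.pi - e)\n",
    "    return float(np.linalg.norm(e, axis=1).sum())\n",
    "\n",
    "demo = np.array([[0.0, 0.0], [3.0, 0.0], [-3.0, 0.0]])\n",
    "total = path_length(demo)  # 3 + (2 * pi - 6)\n",
    "```\n",
    "\n",
    "With the objects of this notebook, it could be applied as `path_length(rrt.get_path(q_g))` and `path_length(birrt.get_path())`, assuming both planners found a solution."
   ]
  }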
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
